{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "ONN_Usage_Example.ipynb",
      "version": "0.3.2",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ugTEiBmBV3Cc",
        "colab_type": "text"
      },
      "source": [
        "# Installing Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "FbEX3YFwKCxy",
        "colab_type": "code",
        "outputId": "ce12c48c-690e-49cb-c8a1-b31f66872022",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 85
        }
      },
      "source": [
        "!pip install --upgrade onn"
      ],
      "execution_count": 7,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already up-to-date: onn in /usr/local/lib/python3.6/dist-packages (0.1.8)\n",
            "Requirement already satisfied, skipping upgrade: torch in /usr/local/lib/python3.6/dist-packages (from onn) (1.1.0)\n",
            "Requirement already satisfied, skipping upgrade: numpy in /usr/local/lib/python3.6/dist-packages (from onn) (1.16.4)\n",
            "Requirement already satisfied, skipping upgrade: mabalgs in /usr/local/lib/python3.6/dist-packages (from onn) (0.6.4)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "CCBv8CqCe5Ao",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 102
        },
        "outputId": "a9c5bc86-af25-4b3a-b827-a5e22c1b2a1d"
      },
      "source": [
        "!pip install -U imbalanced-learn"
      ],
      "execution_count": 8,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Requirement already up-to-date: imbalanced-learn in /usr/local/lib/python3.6/dist-packages (0.5.0)\n",
            "Requirement already satisfied, skipping upgrade: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (0.13.2)\n",
            "Requirement already satisfied, skipping upgrade: numpy>=1.11 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (1.16.4)\n",
            "Requirement already satisfied, skipping upgrade: scipy>=0.17 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (1.3.0)\n",
            "Requirement already satisfied, skipping upgrade: scikit-learn>=0.21 in /usr/local/lib/python3.6/dist-packages (from imbalanced-learn) (0.21.2)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "XbYfcIpkV-SA",
        "colab_type": "text"
      },
      "source": [
        "## Importing Dependencies"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "9xT-ZdVsKH6z",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        },
        "outputId": "899e5370-ad08-40cc-8241-9a577717d828"
      },
      "source": [
        "from onn.OnlineNeuralNetwork import ONN\n",
        "from onn.OnlineNeuralNetwork import ONN_THS\n",
        "from sklearn.datasets import make_classification, make_circles\n",
        "from sklearn.model_selection import train_test_split\n",
        "import torch\n",
        "from sklearn.metrics import accuracy_score, balanced_accuracy_score\n",
        "from imblearn.datasets import make_imbalance\n",
        "import numpy as np"
      ],
      "execution_count": 34,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using TensorFlow backend.\n"
          ],
          "name": "stderr"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "N9H22wxBWB9f",
        "colab_type": "text"
      },
      "source": [
        "## Initializing Network"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "1wOHgHL1LieT",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# features_size and n_classes are both 10, matching the 10-feature, 10-class\n",
        "# make_classification data this notebook trains on.\n",
        "onn_network = ONN(\n",
        "    features_size=10,\n",
        "    max_num_hidden_layers=5,\n",
        "    qtd_neuron_per_hidden_layer=40,\n",
        "    n_classes=10,\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "eYqkIyxyWI2h",
        "colab_type": "text"
      },
      "source": [
        "## Creating a Fake Classification Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "WXgNSF9gL69F",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Seed the generator as well as the split: without a fixed random_state the synthetic\n",
        "# data differs on every run, so seeding train_test_split alone does not make it repeatable.\n",
        "X, Y = make_classification(n_samples=50000, n_features=10, n_informative=4, n_redundant=0, n_classes=10,\n",
        "                           n_clusters_per_class=1, class_sep=3, random_state=42)\n",
        "# Hold out 30% of the samples for evaluation.\n",
        "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.3, random_state=42, shuffle=True)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "THGPFSWJWPEm",
        "colab_type": "text"
      },
      "source": [
        "## Learning and predicting at the same time"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "70J3ZYtmL-Zm",
        "colab_type": "code",
        "outputId": "15babeef-1128-44b7-8dcf-92d4ef3d295a",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "# Feed the training data one sample at a time and, on every 1000th step\n",
        "# (step 0 included), score the network on the held-out test split.\n",
        "for step, (features, label) in enumerate(zip(X_train, y_train)):\n",
        "  onn_network.partial_fit(np.asarray([features]), np.asarray([label]))\n",
        "\n",
        "  if step % 1000 != 0:\n",
        "    continue\n",
        "\n",
        "  predictions = onn_network.predict(X_test)\n",
        "  online_score = balanced_accuracy_score(y_test, predictions)\n",
        "  print(\"Online Accuracy: {}\".format(online_score))"
      ],
      "execution_count": 25,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Online Accuracy: 0.14337461746314914\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83890086 0.04027476 0.04027476 0.04027476 0.04027476]\n",
            "Training Loss: 1.296051\n",
            "Online Accuracy: 0.9606844639729234\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.839583   0.04010423 0.04010423 0.04010423 0.04010423]\n",
            "Training Loss: 0.42200255\n",
            "Online Accuracy: 0.9587807527275185\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8393374  0.04016563 0.04016563 0.04016563 0.04016563]\n",
            "Training Loss: 0.3702089\n",
            "Online Accuracy: 0.967292853621438\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83981675 0.0400458  0.0400458  0.0400458  0.0400458 ]\n",
            "Training Loss: 0.26317188\n",
            "Online Accuracy: 0.9654591078178111\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83988035 0.04002989 0.04002989 0.04002989 0.04002989]\n",
            "Training Loss: 0.25282928\n",
            "Online Accuracy: 0.9665110941847939\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398668  0.04003327 0.04003327 0.04003327 0.04003327]\n",
            "Training Loss: 0.21191452\n",
            "Online Accuracy: 0.9710897029043102\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83997285 0.04000676 0.04000676 0.04000676 0.04000676]\n",
            "Training Loss: 0.21445769\n",
            "Online Accuracy: 0.9719590088952164\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8397224  0.04006939 0.04006939 0.04006939 0.04006939]\n",
            "Training Loss: 0.19888349\n",
            "Online Accuracy: 0.9704734746855797\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399309  0.04001725 0.04001725 0.04001725 0.04001725]\n",
            "Training Loss: 0.19780262\n",
            "Online Accuracy: 0.9682571684488185\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398869  0.04002826 0.04002826 0.04002826 0.04002826]\n",
            "Training Loss: 0.21107836\n",
            "Online Accuracy: 0.9717444211680417\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399914  0.04000212 0.04000212 0.04000212 0.04000212]\n",
            "Training Loss: 0.18996303\n",
            "Online Accuracy: 0.9731342027566192\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.839978   0.04000548 0.04000548 0.04000548 0.04000548]\n",
            "Training Loss: 0.19749445\n",
            "Online Accuracy: 0.9732475060493468\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.839696   0.04007598 0.04007598 0.04007598 0.04007598]\n",
            "Training Loss: 0.19404268\n",
            "Online Accuracy: 0.9723553377085323\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398478  0.04003802 0.04003802 0.04003802 0.04003802]\n",
            "Training Loss: 0.18190624\n",
            "Online Accuracy: 0.9743039644477742\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83997566 0.04000606 0.04000606 0.04000606 0.04000606]\n",
            "Training Loss: 0.1828396\n",
            "Online Accuracy: 0.9738457498877151\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8392171 0.0401957 0.0401957 0.0401957 0.0401957]\n",
            "Training Loss: 0.19953582\n",
            "Online Accuracy: 0.9732231868948086\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399872  0.04000318 0.04000318 0.04000318 0.04000318]\n",
            "Training Loss: 0.16184519\n",
            "Online Accuracy: 0.9735012003132812\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8394533  0.04013667 0.04013667 0.04013667 0.04013667]\n",
            "Training Loss: 0.1797122\n",
            "Online Accuracy: 0.9746659982580329\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399609  0.04000976 0.04000976 0.04000976 0.04000976]\n",
            "Training Loss: 0.19841474\n",
            "Online Accuracy: 0.9751102633737601\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398771 0.0400307 0.0400307 0.0400307 0.0400307]\n",
            "Training Loss: 0.1930369\n",
            "Online Accuracy: 0.9740513941554421\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399562  0.04001092 0.04001092 0.04001092 0.04001092]\n",
            "Training Loss: 0.21371904\n",
            "Online Accuracy: 0.9755335557702551\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399458  0.04001352 0.04001352 0.04001352 0.04001352]\n",
            "Training Loss: 0.19402196\n",
            "Online Accuracy: 0.9758116920399523\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.839987   0.04000323 0.04000323 0.04000323 0.04000323]\n",
            "Training Loss: 0.17364302\n",
            "Online Accuracy: 0.9754555741342109\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83992887 0.04001775 0.04001775 0.04001775 0.04001775]\n",
            "Training Loss: 0.14595963\n",
            "Online Accuracy: 0.9754666701388206\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399415  0.04001459 0.04001459 0.04001459 0.04001459]\n",
            "Training Loss: 0.16576274\n",
            "Online Accuracy: 0.9737672492447718\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399788  0.04000528 0.04000528 0.04000528 0.04000528]\n",
            "Training Loss: 0.17223129\n",
            "Online Accuracy: 0.9755318323234476\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398563 0.0400359 0.0400359 0.0400359 0.0400359]\n",
            "Training Loss: 0.1655618\n",
            "Online Accuracy: 0.9763829709632855\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83996785 0.04000802 0.04000802 0.04000802 0.04000802]\n",
            "Training Loss: 0.15412873\n",
            "Online Accuracy: 0.976408073580625\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83986807 0.04003297 0.04003297 0.04003297 0.04003297]\n",
            "Training Loss: 0.19262518\n",
            "Online Accuracy: 0.9757346501029929\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83998775 0.04000303 0.04000303 0.04000303 0.04000303]\n",
            "Training Loss: 0.16019145\n",
            "Online Accuracy: 0.9764235355129427\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399023  0.04002441 0.04002441 0.04002441 0.04002441]\n",
            "Training Loss: 0.15920378\n",
            "Online Accuracy: 0.9768755609369231\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83994466 0.04001381 0.04001381 0.04001381 0.04001381]\n",
            "Training Loss: 0.12754413\n",
            "Online Accuracy: 0.9754246065697441\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83994985 0.04001251 0.04001251 0.04001251 0.04001251]\n",
            "Training Loss: 0.1729848\n",
            "Online Accuracy: 0.9768689984007384\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83995646 0.04001086 0.04001086 0.04001086 0.04001086]\n",
            "Training Loss: 0.21214448\n",
            "Online Accuracy: 0.9769247559047265\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83999026 0.04000241 0.04000241 0.04000241 0.04000241]\n",
            "Training Loss: 0.133252\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_K01fvy7Tj7Z",
        "colab_type": "text"
      },
      "source": [
        "# Learning in batch with CUDA"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Jg1UJWvDTjJc",
        "colab_type": "code",
        "outputId": "b917847a-45a3-445f-c527-3d449e1de377",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Fresh network for the mini-batch run; batch_size=10 is the same batch size the\n",
        "# DataLoader in this section is built with. use_cuda=True requests GPU execution.\n",
        "onn_network = ONN(\n",
        "    features_size=10,\n",
        "    max_num_hidden_layers=5,\n",
        "    qtd_neuron_per_hidden_layer=40,\n",
        "    n_classes=10,\n",
        "    batch_size=10,\n",
        "    use_cuda=True,\n",
        ")"
      ],
      "execution_count": 26,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Using CUDA :]\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Uoxk5Jk_UYKJ",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "from torch.utils.data import Dataset, DataLoader\n",
        "\n",
        "\n",
        "class Dataset(Dataset):\n",
        "  \"\"\"Pairs each row of X with the label at the same position in Y.\n",
        "\n",
        "  NOTE(review): this class rebinds the name `Dataset` imported on the line above;\n",
        "  a distinct name would be clearer, but the name is kept so existing references still work.\n",
        "  \"\"\"\n",
        "\n",
        "  def __init__(self, X, Y):\n",
        "    # Keep references to the arrays; samples are looked up in __getitem__.\n",
        "    self.X = X\n",
        "    self.Y = Y\n",
        "\n",
        "  def __len__(self):\n",
        "    return len(self.X)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    # The features are wrapped in a one-element tuple, so each sample is ((row,), label).\n",
        "    # NOTE(review): the training cell's torch.stack/np.squeeze appears to depend on this shape.\n",
        "    return (self.X[idx],), self.Y[idx]\n",
        "\n",
        "\n",
        "transformed_dataset = Dataset(X_train, y_train)\n",
        "# Shuffled mini-batches of 10 samples, with num_workers=1.\n",
        "dataloader = DataLoader(transformed_dataset, batch_size=10, shuffle=True, num_workers=1)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "8jz4S-SVUAMx",
        "colab_type": "code",
        "outputId": "eb25d974-e8af-4e51-e9f0-5f2501b02b56",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 170
        }
      },
      "source": [
        "# Each batch's features arrive as a sequence of tensors: stack them into one tensor,\n",
        "# convert to NumPy and squeeze out the singleton axes before fitting.\n",
        "for local_X, local_y in dataloader:\n",
        "  batch_features = np.squeeze(torch.stack(local_X).numpy())\n",
        "  batch_labels = local_y.numpy()\n",
        "  onn_network.partial_fit(batch_features, batch_labels)"
      ],
      "execution_count": 28,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83902884 0.04024278 0.04024278 0.04024278 0.04024278]\n",
            "Training Loss: 1.577592\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8394673  0.04013318 0.04013318 0.04013318 0.04013318]\n",
            "Training Loss: 0.5495532\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8397412 0.0400647 0.0400647 0.0400647 0.0400647]\n",
            "Training Loss: 0.40501082\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "u7XwSO6zVpsz",
        "colab_type": "code",
        "outputId": "6f19825b-4435-442c-c710-4003ef7b555c",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 34
        }
      },
      "source": [
        "# Score the batch-trained network on the held-out test split.\n",
        "predictions = onn_network.predict(X_test)\n",
        "batch_score = balanced_accuracy_score(y_test, predictions)\n",
        "print(\"Accuracy: {}\".format(batch_score))"
      ],
      "execution_count": 29,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Accuracy: 0.9517950352535276\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DSk4Fl4GV9NG",
        "colab_type": "text"
      },
      "source": [
        "# Using contextual bandit - ONN_THS"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "vBweUP_CZaFz",
        "colab_type": "text"
      },
      "source": [
        "In this example the ONN acts as a contextual bandit, a type of reinforcement learning algorithm."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "TeFBn4erUDY3",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Three synthetic two-class problems: two make_classification sets with a very large\n",
        "# class separation and one concentric-circles set. Each generator gets its own fixed\n",
        "# seed so the data is the same on every run.\n",
        "X_linear, Y_linear = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True, random_state=0)\n",
        "X_non_linear, Y_non_linear = make_circles(n_samples=10000, noise=0.1, factor=0.3, shuffle=True, random_state=1)\n",
        "X_linear_2, Y_linear_2 = make_classification(n_samples=10000, n_features=2, n_informative=2, n_redundant=0, n_classes=2, n_clusters_per_class=1, class_sep=200, shuffle=True, random_state=2)\n",
        "\n",
        "# For every set, the first 5000 samples are used for training and the rest for testing.\n",
        "X_linear_train = X_linear[:5000]\n",
        "Y_linear_train = Y_linear[:5000]\n",
        "\n",
        "X_linear_test = X_linear[5000:]\n",
        "Y_linear_test = Y_linear[5000:]\n",
        "\n",
        "X_non_linear_train = X_non_linear[:5000]\n",
        "Y_non_linear_train = Y_non_linear[:5000]\n",
        "\n",
        "X_non_linear_test = X_non_linear[5000:]\n",
        "Y_non_linear_test = Y_non_linear[5000:]\n",
        "\n",
        "X_linear_train_2 = X_linear_2[:5000]\n",
        "Y_linear_train_2 = Y_linear_2[:5000]\n",
        "\n",
        "X_linear_test_2 = X_linear_2[5000:]\n",
        "Y_linear_test_2 = Y_linear_2[5000:]"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "lfnN1xG7WtF7",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Keyword arguments make the four sizes self-describing, consistent with the ONN(...) calls above.\n",
        "gp = ONN_THS(features_size=2, max_num_hidden_layers=5, qtd_neuron_per_hidden_layer=100, n_classes=2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "7EmKF-y8Wytt",
        "colab_type": "code",
        "outputId": "b6b7b943-d466-4274-cd11-e1325290412b",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "for epoch in range(5):\n",
        "\n",
        "    for i in range(len(X_linear_train)):\n",
        "        x = np.asarray([X_linear_train[i, :]])\n",
        "        y = np.asarray([Y_linear_train[i]])\n",
        "\n",
        "        # predict returns the chosen arm plus a second value that partial_fit takes back.\n",
        "        arm, exp = gp.predict(x)\n",
        "\n",
        "        # The network is only updated when the chosen arm equals the true label.\n",
        "        if arm == y[0]:\n",
        "            gp.partial_fit(x, y, exp)\n",
        "\n",
        "        # After every 2000th training sample, score the model on the linear test split.\n",
        "        if i % 2000 == 1999:\n",
        "            print(\"======================================================\")\n",
        "            # A separate index name keeps the evaluation from overwriting the training index `i`.\n",
        "            pred = [gp.predict(np.asarray([X_linear_test[j, :]]))[0]\n",
        "                    for j in range(len(X_linear_test))]\n",
        "            print(\"Accuracy: \" + str(balanced_accuracy_score(Y_linear_test, pred)))\n",
        "            print(\"======================================================\")\n",
        "\n",
        "print('Finished Training')"
      ],
      "execution_count": 32,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 7.708073e-07\n",
            "======================================================\n",
            "Accuracy: 0.9320135479448636\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.66480255 0.08414558 0.08314086 0.08373952 0.0841715 ]\n",
            "Training Loss: 1.4113159\n",
            "======================================================\n",
            "Accuracy: 0.7803689523703463\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.4597675  0.1319707  0.13887632 0.13379806 0.13558747]\n",
            "Training Loss: 2.1171865\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.8254292\n",
            "======================================================\n",
            "Accuracy: 0.8239034376118035\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 1.9921635\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.09616726 0.22576231 0.22466676 0.22605874 0.2273449 ]\n",
            "Training Loss: 1.2512604\n",
            "======================================================\n",
            "Accuracy: 0.9194272332380344\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.7483906  0.06154997 0.0657564  0.0615586  0.06274442]\n",
            "Training Loss: 1.2462758\n",
            "======================================================\n",
            "Accuracy: 0.9089153835864292\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 1.1584326\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 1.0460584\n",
            "======================================================\n",
            "Accuracy: 0.9091254701747447\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.04581921 0.22431232 0.26757908 0.22434013 0.23794925]\n",
            "Training Loss: 1.4130515\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.00093426596\n",
            "======================================================\n",
            "Accuracy: 0.9230967883152815\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 1.7778293\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9200642132137263\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.1015042  0.22482882 0.22428374 0.22482881 0.22455448]\n",
            "Training Loss: 1.1210397\n",
            "======================================================\n",
            "Accuracy: 0.9140752253990685\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.00011777431\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.13615824 0.21326509 0.22357734 0.21326567 0.21373361]\n",
            "Training Loss: 1.4623655\n",
            "======================================================\n",
            "Accuracy: 0.9162808945688536\n",
            "======================================================\n",
            "Finished Training\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "G6myrkQjW1kj",
        "colab_type": "code",
        "outputId": "873af04f-482d-4ae5-89fe-8df04f21cc38",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        }
      },
      "source": [
        "for epoch in range(5):\n",
        "\n",
        "    for i in range(len(X_non_linear_train)):\n",
        "        x = np.asarray([X_non_linear_train[i, :]])\n",
        "        y = np.asarray([Y_non_linear_train[i]])\n",
        "\n",
        "        # predict returns the chosen arm plus a second value that partial_fit takes back.\n",
        "        arm, exp = gp.predict(x)\n",
        "\n",
        "        # The network is only updated when the chosen arm equals the true label.\n",
        "        if arm == y[0]:\n",
        "            gp.partial_fit(x, y, exp)\n",
        "\n",
        "        # After every 2000th training sample, score the model on the non-linear test split.\n",
        "        if i % 2000 == 1999:\n",
        "            print(\"======================================================\")\n",
        "            # Iterate over the non-linear test split itself (not X_linear_test), and use a\n",
        "            # separate index name so the training index `i` is not overwritten.\n",
        "            pred = [gp.predict(np.asarray([X_non_linear_test[j, :]]))[0]\n",
        "                    for j in range(len(X_non_linear_test))]\n",
        "            print(\"Accuracy: \" + str(balanced_accuracy_score(Y_non_linear_test, pred)))\n",
        "            print(\"======================================================\")\n",
        "\n",
        "print('Finished Training')"
      ],
      "execution_count": 33,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.7324713  0.06278172 0.07632934 0.0627818  0.06563586]\n",
            "Training Loss: 0.565434\n",
            "======================================================\n",
            "Accuracy: 0.4978415196546431\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83837765 0.04006734 0.04134858 0.04006734 0.04013911]\n",
            "Training Loss: 0.32934776\n",
            "======================================================\n",
            "Accuracy: 0.6431012228961956\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83944225 0.04013941 0.04013941 0.04013941 0.04013941]\n",
            "Training Loss: 0.27433354\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83982235 0.04004439 0.04004439 0.04004439 0.04004439]\n",
            "Training Loss: 0.2834035\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83984315 0.04003919 0.04003919 0.04003919 0.04003919]\n",
            "Training Loss: 0.2099059\n",
            "======================================================\n",
            "Accuracy: 0.7659370825499332\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8395842  0.04010393 0.04010393 0.04010393 0.04010393]\n",
            "Training Loss: 0.21292916\n",
            "======================================================\n",
            "Accuracy: 0.9151972664315626\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83984536 0.04003864 0.04003864 0.04003864 0.04003864]\n",
            "Training Loss: 0.19620447\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398821  0.04002946 0.04002946 0.04002946 0.04002946]\n",
            "Training Loss: 0.16505778\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83968943 0.04007762 0.04007762 0.04007762 0.04007762]\n",
            "Training Loss: 0.14461489\n",
            "======================================================\n",
            "Accuracy: 0.9197904671664747\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399311  0.04001722 0.04001722 0.04001722 0.04001722]\n",
            "Training Loss: 0.124461606\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8396594  0.04008514 0.04008514 0.04008514 0.04008514]\n",
            "Training Loss: 0.115024015\n",
            "======================================================\n",
            "Accuracy: 0.9301949488311918\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398307  0.04004231 0.04004231 0.04004231 0.04004231]\n",
            "Training Loss: 0.10779811\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8398185  0.04004535 0.04004535 0.04004535 0.04004535]\n",
            "Training Loss: 0.099823594\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399364  0.04001588 0.04001588 0.04001588 0.04001588]\n",
            "Training Loss: 0.092471905\n",
            "======================================================\n",
            "Accuracy: 0.8999838239974118\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.839942   0.04001448 0.04001448 0.04001448 0.04001448]\n",
            "Training Loss: 0.08162724\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83992773 0.04001804 0.04001804 0.04001804 0.04001804]\n",
            "Training Loss: 0.0805736\n",
            "======================================================\n",
            "Accuracy: 0.9387956702073073\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83965665 0.04008581 0.04008581 0.04008581 0.04008581]\n",
            "Training Loss: 0.081078894\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8397386  0.04006533 0.04006533 0.04006533 0.04006533]\n",
            "Training Loss: 0.072354406\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399574  0.04001063 0.04001063 0.04001063 0.04001063]\n",
            "Training Loss: 0.06984673\n",
            "======================================================\n",
            "Accuracy: 0.9489987918398066\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399927 0.0400018 0.0400018 0.0400018 0.0400018]\n",
            "Training Loss: 0.06360293\n",
            "======================================================\n",
            "Accuracy: 0.9428014308482289\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.83996755 0.0400081  0.0400081  0.0400081  0.0400081 ]\n",
            "Training Loss: 0.06630126\n",
            "Finished Training\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "G27Vh6mI6LTi",
        "colab_type": "text"
      },
      "source": [
        "# Imbalanced Dataset"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "E-IHNagYXgap",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X, Y = make_classification(n_samples=110000, n_features=20, n_classes=10, n_informative=8, n_redundant=0, n_clusters_per_class=1, class_sep=900)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "O5Mh3jR-6TWR",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_t, Y_t = make_imbalance(X, Y, sampling_strategy={0: 800, 1: 5000, 2: 10000, 3: 10000, 4: 1000, 5: 1000, 6: 500, 7: 10000, 8: 5000, 9:5000})"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ryU12keU6WOp",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2)\n",
        "\n",
        "X_train_t, X_test_t, y_train_t, y_test_t = train_test_split(X_t, Y_t, test_size=0.2)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "Si9uRxRu7Ecw",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "gp = ONN_THS(20, 5, 100, 10)"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "a21-OBBC6YrK",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "9285bd62-10a5-4f71-b8ec-22097ae3a808"
      },
      "source": [
        "for i in range(len(X_train)):\n",
        "    x = np.asarray([X_train[i, :]])\n",
        "    y = np.asarray([y_train[i]])\n",
        "\n",
        "    arm, exp = gp.predict(x)\n",
        "\n",
        "    if arm == y[0]:  \n",
        "      gp.partial_fit(x, y, exp)\n",
        "\n",
        "    if i % 2000 == 1999:\n",
        "      pred = []\n",
        "      print(\"======================================================\")\n",
        "      for i in range(len(X_test)):  \n",
        "        pred.append(gp.predict(np.asarray([X_test[i, :]]))[0])\n",
        "      print(\"Accuracy: \" + str(balanced_accuracy_score(y_test, pred)))\n",
        "      print(\"======================================================\")\n",
        "\n",
        "print('Finished Training')"
      ],
      "execution_count": 49,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "======================================================\n",
            "Accuracy: 0.4353221256829899\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.06869241 0.16809322 0.3725306  0.17724589 0.21343793]\n",
            "Training Loss: 9.325638\n",
            "======================================================\n",
            "Accuracy: 0.4386085948153132\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.11643752 0.20428126 0.23810413 0.20844749 0.23272966]\n",
            "Training Loss: 11.25701\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 2.6811903\n",
            "======================================================\n",
            "Accuracy: 0.6928950497293518\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 10.146265\n",
            "======================================================\n",
            "Accuracy: 0.6838257488564606\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 8.005989\n",
            "======================================================\n",
            "Accuracy: 0.6115088279471166\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.79145914 0.05235159 0.05036381 0.05197832 0.05384716]\n",
            "Training Loss: 3.7819355\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 8.093227\n",
            "======================================================\n",
            "Accuracy: 0.7768358673631839\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.7728323692342458\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.7750311850166554\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.23982884 0.19055171 0.18932813 0.19059701 0.18969429]\n",
            "Training Loss: 13.883859\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8685216303769551\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8026281358448346\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 17.250746\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 2.5108016\n",
            "======================================================\n",
            "Accuracy: 0.8674449681421013\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8713122081601006\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.86805856622274\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8673653690806932\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8709340126479518\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 26.645742\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.5200308  0.12016081 0.11971524 0.12033709 0.11975605]\n",
            "Training Loss: 5.1447864\n",
            "======================================================\n",
            "Accuracy: 0.7919937075157115\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 16.11202\n",
            "======================================================\n",
            "Accuracy: 0.8733754651900446\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8717504131399775\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.873564110770424\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 15.066921\n",
            "======================================================\n",
            "Accuracy: 0.7295902241565401\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 2.8158276\n",
            "======================================================\n",
            "Accuracy: 0.8713796789076305\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 22.253164\n",
            "======================================================\n",
            "Accuracy: 0.8771001088342661\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.53331    0.11533871 0.11965422 0.11555222 0.1161448 ]\n",
            "Training Loss: 13.497225\n",
            "======================================================\n",
            "Accuracy: 0.8746053463042536\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 4.758543\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8741468052285819\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8795648644911169\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 16.302721\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 3.1496372\n",
            "======================================================\n",
            "Accuracy: 0.8837012263713927\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.902540151229202\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9021177354976991\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 8.878667\n",
            "======================================================\n",
            "Accuracy: 0.8079567207927605\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8039658939000326\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 6.192695\n",
            "======================================================\n",
            "Accuracy: 0.9009930184879561\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 10.917292\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.900008182367175\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.7559461  0.06102359 0.0612076  0.0608501  0.06097254]\n",
            "Training Loss: 10.377898\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9055045133877494\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9021639784287858\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 10.628485\n",
            "======================================================\n",
            "Accuracy: 0.9016796783336091\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 8.878321\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8994004273359604\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9050839928229234\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9044701066116776\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9047566712793573\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9063798424927105\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9008264580869225\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9001214892044537\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.902945101154504\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9009983483663652\n",
            "======================================================\n",
            "Finished Training\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "eqBRyZxh-6_h",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Build a fresh ONN_THS model for the training loop in the next cell.\n",
        "# Keyword names follow the constructor documented in the onn package README.\n",
        "gp = ONN_THS(\n",
        "    features_size=20,\n",
        "    max_num_hidden_layers=5,\n",
        "    qtd_neuron_per_hidden_layer=100,\n",
        "    n_classes=10,\n",
        ")"
      ],
      "execution_count": 0,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "mFFz2xWN6o_4",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 1000
        },
        "outputId": "ee7d349a-2efb-44b3-d44c-7a1976c903a3"
      },
      "source": [
        "# Online training loop: feed the training set to `gp` one sample at a time.\n",
        "for i in range(len(X_train_t)):\n",
        "    x = np.asarray([X_train_t[i, :]])\n",
        "    y = np.asarray([y_train_t[i]])\n",
        "\n",
        "    arm, exp = gp.predict(x)\n",
        "\n",
        "    # The model is only updated when the arm it picked equals the true label;\n",
        "    # the second value returned by predict() is handed back to partial_fit().\n",
        "    if arm == y[0]:\n",
        "        gp.partial_fit(x, y, exp)\n",
        "\n",
        "    # After every 2000th training sample, score the current model on the test set.\n",
        "    if i % 2000 == 1999:\n",
        "        print(\"======================================================\")\n",
        "        # A comprehension-scoped index (j) keeps the evaluation pass from\n",
        "        # overwriting the outer training index i.\n",
        "        pred = [\n",
        "            gp.predict(np.asarray([X_test_t[j, :]]))[0]\n",
        "            for j in range(len(X_test_t))\n",
        "        ]\n",
        "        print(\"Accuracy: \" + str(balanced_accuracy_score(y_test_t, pred)))\n",
        "        print(\"======================================================\")\n",
        "\n",
        "print('Finished Training')"
      ],
      "execution_count": 51,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.31641227 0.12936862 0.17681654 0.18663488 0.19076766]\n",
            "Training Loss: 4.9937906\n",
            "======================================================\n",
            "Accuracy: 0.5121877357778365\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.11614554 0.1702754  0.2646138  0.1885041  0.2604611 ]\n",
            "Training Loss: 9.42681\n",
            "======================================================\n",
            "Accuracy: 0.60000532007458\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 3.760187\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.6370007505514537\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.6198465258532846\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 1.579348\n",
            "======================================================\n",
            "Accuracy: 0.5121875529066693\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 9.70304\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.6330198059531864\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 10.2148285\n",
            "======================================================\n",
            "Accuracy: 0.7084962849287858\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.7265182796040035\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.7239405864519876\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.12717529\n",
            "======================================================\n",
            "Accuracy: 0.810706833334715\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 9.516822\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 2.7376916\n",
            "======================================================\n",
            "Accuracy: 0.8972895582247269\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 8.428832\n",
            "======================================================\n",
            "Accuracy: 0.90217251261295\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 11.3074665\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9124533480316206\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9035376710667149\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9045845614725639\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.6930886  0.07199327 0.08099582 0.07299026 0.08093201]\n",
            "Training Loss: 12.486347\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8146383732762196\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.8232097488414242\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 6.5098047\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9154437246746809\n",
            "======================================================\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "WARNING: Set 'show_loss' to 'False' when not debugging. It will deteriorate the fitting performance.\n",
            "Alpha:[0.8399999 0.04      0.04      0.04      0.04     ]\n",
            "Training Loss: 0.0\n",
            "======================================================\n",
            "Accuracy: 0.9127977126517685\n",
            "======================================================\n",
            "Finished Training\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "X3IBIV0WAqKI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": 0,
      "outputs": []
    }
  ]
}